package com.twitter.finagle.serverset2

import com.twitter.conversions.time._
import com.twitter.concurrent.NamedPoolThreadFactory
import com.twitter.finagle.stats.{FinagleStatsReceiver, Stat}
import com.twitter.finagle.{Addr, Address}
import com.twitter.finagle.addr.WeightedAddress
import com.twitter.util._

/**
 * An Epoch bundles an `Event` that is expected to notify its
 * listeners once per `period` together with that `period`.
 *
 * @param event  the event that fires at each turn of the epoch
 * @param period the nominal interval between two turns of the epoch
 */
private[serverset2] class Epoch(
  val event: Event[Unit],
  val period: Duration
)

private[serverset2] object Epoch {

  // Default timer shared by every Epoch that is not given an explicit one:
  // a single-thread pool whose thread is created as a daemon.
  private val epochTimer: Timer = {
    val threadFactory =
      new NamedPoolThreadFactory("finagle-serversets Stabilizer timer", /*makeDaemons = */ true)
    new ScheduledThreadPoolTimer(poolSize = 1, threadFactory)
  }

  // Records how long each witness notification took, in milliseconds.
  private val notifyMs: Stat =
    FinagleStatsReceiver.scope("serverset2", "stabilizer").stat("notify_ms")

  /**
   * Build an [[Epoch]] whose event fires repeatedly on `timer`.
   *
   * @param period the epoch's period; it is reported unchanged on the
   *               returned Epoch, but scheduling uses at least 1 millisecond
   * @param timer  the timer each registration is scheduled on
   */
  def apply(period: Duration, timer: Timer = epochTimer): Epoch = {
    // Some Timers have a minimum floor, so never schedule below 1ms.
    val tickInterval = period.max(1.millisecond)

    val ticks: Event[Unit] = new Event[Unit] {
      // Every registration gets its own periodic task; closing the
      // returned Closable is how that task is cancelled.
      def register(w: Witness[Unit]): Closable =
        timer.schedule(tickInterval) {
          val stopwatch = Stopwatch.start()
          w.notify(())
          notifyMs.add(stopwatch().inMilliseconds)
        }
    }

    new Epoch(ticks, period)
  }

}

private[serverset2] object Stabilizer {

  /**
   * State used for delaying removals.
   *
   * @param limbo  the active set as it was at the turn of the last removalEpoch
   * @param active the addresses observed during the current removalEpoch
   * @param addr   the last observed value of the underlying `Var[Addr]`
   */
  private case class State(limbo: Option[Set[Address]], active: Option[Set[Address]], addr: Addr)

  /**
   * State used for batching updates.
   *
   * @param publish  the address to emit for this step; None emits nothing
   * @param last     the most recently published address
   * @param next     an update held back until the batch period has elapsed
   * @param lastEmit when `last` was published (`Time.Zero` before any publish)
   */
  private case class States(publish: Option[Addr], last: Addr, next: Option[Addr], lastEmit: Time)

  private val initState = State(None, None, Addr.Pending)

  /**
   * Stabilize the address relative to the supplied source of removalEpochs,
   * such that any removed socket address in an Addr.Bound set is
   * placed in a limbo state until at least one removalEpoch has passed.
   * Also batches all updates (adds and removes) to trigger at most once
   * per batchEpoch.
   *
   * In practice, the source of removalEpochs must correlate with a failure
   * detection interval; we consider an address dead if it has not
   * been observed for at least one removalEpoch, and no failures
   * (Addr.Failed) have been observed in the same interval.
   *
   * All changes, added and removed socket addresses, are batched by
   * one update per batchEpoch, delaying adds by at most one batchEpoch.
   *
   * @param va           the raw, unstabilized address
   * @param removalEpoch how long a removed address lingers; a period of
   *                     `Duration.Zero` disables the delay without a timer
   * @param batchEpoch   the minimum interval between published updates; a
   *                     period of `Duration.Zero` disables batching
   * @return a Var, initially `Addr.Pending`, carrying the stabilized address
   */
  def apply(va: Var[Addr], removalEpoch: Epoch, batchEpoch: Epoch): Var[Addr] =
    Var.async(Addr.Pending: Addr) { u =>
      // We construct an Event[State] representing states after
      // successive address observations. The state contains two sets
      // of addresses: the "active" set is the set of all observed
      // addresses in the current removalEpoch; the "limbo" set is the
      // active set at the turn of the last removalEpoch.
      //
      // Whenever a failure is observed, the limbo set is promoted by
      // merging it into the active set; thus we can guarantee that the
      // merged limbo and active sets contain all addresses seen in at
      // least one removalEpoch's period without intermittent failure.
      //
      // Our state also maintains the last observed value of `va`.
      //
      // Thus we interpret the stabilized address to be
      // Addr.Bound(merge(limbo, active)) when these are nonempty;
      // otherwise the last observed address.
      //
      // The updates to this stabilized address are then batched and
      // triggered at most once per batchEpoch.
      val addrOrEpoch: Event[Either[Addr, Unit]] =
        if (removalEpoch.period == Duration.Zero) {
          // If removalEpoch is 0 seconds then removals are not delayed but other
          // behavior is preserved. For instance, we still want to cache last
          // seen addresses in the case of failure. Do this efficiently by
          // triggering epochs manually rather than spinning a timer. Trigger
          // two epochs to ensure a complete epoch period has "passed", expiring
          // the existing limbo set.
          new Event[Either[Addr, Unit]] {
            def register(s: Witness[Either[Addr, Unit]]): Closable = {
              va.changes.respond { addr =>
                s.notify(Left(addr))
                s.notify(Right(()))
                s.notify(Right(()))
              }
            }
          }
        } else va.changes.select(removalEpoch.event)

      val states: Event[State] = addrOrEpoch.foldLeft(initState) {
        // Addr update
        case (st @ State(limbo, active, _), Left(addr)) =>
          addr match {
            case Addr.Failed(_) =>
              // Promote limbo into active. Use `merge` (not a raw `++`) so
              // that an address present in both sets with different weights
              // appears only once, carrying the more recent (active) weight.
              State(
                None,
                Some(merge(limbo.getOrElse(Set.empty), active.getOrElse(Set.empty))),
                addr
              )

            case Addr.Bound(bound, _) =>
              State(limbo, Some(merge(active.getOrElse(Set.empty), bound)), addr)

            case Addr.Neg if active.isEmpty && limbo.isEmpty =>
              State(limbo, Some(Set.empty), addr)

            case other =>
              // Any other address simply propagates the address while
              // leaving the active/limbo set unchanged. Both active
              // and limbo have to expire in order for this address to
              // propagate.
              st.copy(addr = other)
          }

        // The removalEpoch turned
        case (st @ State(_, active, last), Right(())) =>
          last match {
            case Addr.Bound(bound, _) =>
              // Demote active to limbo; the new active set restarts from
              // the last observed bound set.
              State(active, Some(bound), last)

            case Addr.Neg =>
              State(active, Some(Set.empty), Addr.Neg)

            case Addr.Pending | Addr.Failed(_) =>
              // If the last address is nonbound, we ignore it and
              // maintain our state; we cannot demote the active set
              // when nonbound, since that would eventually promote
              // the address
              st
          }
      }

      val addrs: Event[Addr] = states.map {
        case State(limbo, active, last) =>
          val all = merge(limbo.getOrElse(Set.empty), active.getOrElse(Set.empty))
          if (all.nonEmpty) {
            Addr.Bound(all)
          } else if (limbo.isDefined || active.isDefined) {
            // Sets have been observed but are now empty.
            Addr.Neg
          } else {
            // Nothing has been observed yet: pass the raw address through.
            last
          }
      }

      // Trigger at most one change to state per batchEpoch
      val batchedUpdates: Event[Addr] =
        if (batchEpoch.period == Duration.Zero) {
          addrs
        } else {
          val init = States(Some(Addr.Pending), Addr.Pending, None, Time.Zero)
          addrs
            .select(batchEpoch.event)
            .foldLeft(init) {
              case (st, ev) =>
                val now = Time.now
                ev match {
                  case Left(newAddr) =>
                    if (newAddr == st.last) {
                      // The serverset is (back to) what was last published.
                      // Nothing to emit, and any held update is now stale, so
                      // drop it rather than publishing it at a later epoch.
                      st.copy(publish = None, next = None)
                    } else if (now - st.lastEmit < batchEpoch.period) {
                      // There's a change to the serverset, but we have
                      // published in < batchEpoch. Hold change.
                      st.copy(publish = None, next = Some(newAddr))
                    } else {
                      // There's a change to the serverset and we haven't
                      // published in >= batchEpoch. Publish.
                      States(Some(newAddr), newAddr, None, now)
                    }

                  case Right(_) =>
                    st.next match {
                      // Epoch turned, but we have published in < batchEpoch. Noop
                      case _ if now - st.lastEmit < batchEpoch.period =>
                        st.copy(publish = None)
                      // Epoch turned but there is no next state. Noop
                      case None =>
                        st.copy(publish = None)
                      // Epoch turned, there's a next state, and we haven't published in >= batchEpoch. Publish.
                      case Some(next) =>
                        States(Some(next), next, None, now)
                    }
                }
            }
            .collect {
              case States(Some(publish), _, _, _) => publish
            }
        }

      batchedUpdates.register(Witness(u))
    }

  /**
   * Merge WeightedSocketAddresses with same underlying SocketAddress
   * preferring weights from `next` over `prev`.
   *
   * @param prev the older set; an entry survives only if its unweighted
   *             address does not occur in `next`
   * @param next the newer set; all of its entries are kept as-is
   * @return the surviving entries of `prev` together with all of `next`
   */
  private def merge(prev: Set[Address], next: Set[Address]): Set[Address] = {
    val nextUnweighted = next.map(WeightedAddress.extract(_)._1)

    val legacy = prev.filter { addr =>
      val (unweighted, _) = WeightedAddress.extract(addr)
      !nextUnweighted.contains(unweighted)
    }

    legacy ++ next
  }
}
